Analyse predictions of seasonal surface DOC

Model evaluation, interpretation & projections

Author

Thelma Panaïotis

Set-up and load data

source("utils.R")
load("data/07.doc_surf_seas_pred.Rdata")

Model evaluation

Rsquares

Black dots on the R² boxplots show the actual values.

# Flatten the nested per-fold predictions into one long table
preds <- unnest(select(res, fold, season, preds), preds)

# R² between `log_doc_surf` and `.pred`, one value per season and fold
rsquares <- rsq(
  group_by(preds, season, fold),
  truth = log_doc_surf,
  estimate = .pred
)

# Summary of the R² values, season by season
map(split(rsquares, rsquares$season), summary)
$`1`
    season              fold             .metric           .estimator       
 Length:10          Length:10          Length:10          Length:10         
 Class :character   Class :character   Class :character   Class :character  
 Mode  :character   Mode  :character   Mode  :character   Mode  :character  
                                                                            
                                                                            
                                                                            
   .estimate     
 Min.   :0.7783  
 1st Qu.:0.8121  
 Median :0.8285  
 Mean   :0.8303  
 3rd Qu.:0.8520  
 Max.   :0.8940  

$`2`
    season              fold             .metric           .estimator       
 Length:10          Length:10          Length:10          Length:10         
 Class :character   Class :character   Class :character   Class :character  
 Mode  :character   Mode  :character   Mode  :character   Mode  :character  
                                                                            
                                                                            
                                                                            
   .estimate     
 Min.   :0.4867  
 1st Qu.:0.6505  
 Median :0.7906  
 Mean   :0.7459  
 3rd Qu.:0.8446  
 Max.   :0.8958  

$`3`
    season              fold             .metric           .estimator       
 Length:10          Length:10          Length:10          Length:10         
 Class :character   Class :character   Class :character   Class :character  
 Mode  :character   Mode  :character   Mode  :character   Mode  :character  
                                                                            
                                                                            
                                                                            
   .estimate     
 Min.   :0.4037  
 1st Qu.:0.5679  
 Median :0.6516  
 Mean   :0.6237  
 3rd Qu.:0.7150  
 Max.   :0.7722  

$`4`
    season              fold             .metric           .estimator       
 Length:10          Length:10          Length:10          Length:10         
 Class :character   Class :character   Class :character   Class :character  
 Mode  :character   Mode  :character   Mode  :character   Mode  :character  
                                                                            
                                                                            
                                                                            
   .estimate     
 Min.   :0.5235  
 1st Qu.:0.6418  
 Median :0.6817  
 Mean   :0.6897  
 3rd Qu.:0.7499  
 Max.   :0.7980  
# Boxplot of R² per season, with the individual fold values overlaid as points
ggplot(rsquares, aes(x = season, y = .estimate)) +
  geom_boxplot(aes(group = season, colour = season)) +
  geom_jitter(size = 0.5, width = 0.1) +
  scale_y_continuous(limits = c(0, 1), expand = c(0, 0)) +
  labs(x = "Season", y = "R²", colour = "Season")

Predictions VS truth

Plot pred VS truth on the test part of each fold of each season.

# One panel per season and fold; the red line is the 1:1 line
ggplot(preds, aes(x = log_doc_surf, y = .pred)) +
  geom_point(aes(colour = season)) +
  geom_abline(intercept = 0, slope = 1, colour = "red") +
  coord_fixed() +
  facet_wrap(season ~ fold, ncol = 5)

Now let’s focus on a representative fold for each season.

# For each season, keep the fold whose R² is closest to the seasonal median
# (the first one if several folds are equally close)
repres_fold <- rsquares %>%
  group_by(season) %>%
  mutate(diff = abs(.estimate - median(.estimate))) %>%
  slice(which.min(diff))

# Predictions vs truth, restricted to the representative fold of each season
ggplot(
  left_join(
    select(repres_fold, season, fold),
    preds,
    by = join_by(season, fold)
  )
) +
  geom_point(aes(x = log_doc_surf, y = .pred, colour = season)) +
  geom_abline(intercept = 0, slope = 1, colour = "red") +
  coord_fixed() +
  labs(title = "Pred VS truth for a representative fold") +
  facet_wrap(~season, ncol = 2)

Model interpretation

Variable importance

Variable importance for each fold of each season.

# Unnest variable importance
full_vip <- res %>%
  select(season, fold, importance) %>%
  unnest(importance) %>%
  # order the variable levels by dropout loss, for plotting
  mutate(variable = forcats::fct_reorder(variable, dropout_loss))

# Variable importance across folds
full_vip %>%
  filter(variable != "_full_model_") %>%
  ggplot() +
  # Baseline: mean loss of the "_full_model_" rows within each panel
  # (presumably the loss without any permutation, to be confirmed).
  # It is summarised per season and fold beforehand: an aggregate written
  # inside aes() is evaluated on the whole layer data, which would draw the
  # same global mean in every facet.
  geom_vline(
    data = full_vip %>%
      filter(variable == "_full_model_") %>%
      group_by(season, fold) %>%
      summarise(dropout_loss = mean(dropout_loss), .groups = "drop"),
    aes(xintercept = dropout_loss),
    colour = "grey", linewidth = 2
  ) +
  geom_boxplot(aes(x = dropout_loss, y = variable, colour = season)) +
  labs(x = "RMSE after permutations") +
  facet_grid(fold~season)

Now let’s take the mean across folds of each season.

# Mean dropout loss per season, fold and variable; the boxplots then show the
# spread of these means across folds
full_vip %>%
  filter(variable != "_full_model_") %>%
  group_by(season, fold, variable) %>%
  summarise(dropout_loss = mean(dropout_loss), .groups = "drop") %>%
  ggplot() +
  # Baseline: mean loss of the "_full_model_" rows within each season.
  # It is summarised per season beforehand: an aggregate written inside aes()
  # is evaluated on the whole layer data, which would draw the same global
  # mean in every facet.
  geom_vline(
    data = full_vip %>%
      filter(variable == "_full_model_") %>%
      group_by(season) %>%
      summarise(dropout_loss = mean(dropout_loss), .groups = "drop"),
    aes(xintercept = dropout_loss),
    colour = "grey", linewidth = 2
  ) +
  geom_boxplot(aes(x = dropout_loss, y = variable, colour = season)) +
  labs(x = "Mean RMSE after permutations across CV folds") +
  facet_wrap(~season, ncol = 4)

Partial dependence plots

Finally, let’s have a look at partial dependence plots.

  • line: prediction mean across cp profiles (coloured by season)

  • ribbon: prediction sd across centered cp profiles (filled by season)

# Number of variables to keep per season for the partial dependence plots
n_pdp <- 3

# Rank variables by their mean dropout loss and keep the first `n_pdp` of
# each season (the result stays grouped by season)
vars_pdp <- full_vip %>%
  mutate(variable = as.character(variable)) %>%
  filter(variable != "_full_model_") %>%
  group_by(season, variable) %>%
  summarise(dropout_loss = mean(dropout_loss), .groups = "drop") %>%
  arrange(desc(dropout_loss)) %>%
  group_by(season) %>%
  slice_head(n = n_pdp)

# Flatten the nested ceteris-paribus profiles into one long table
cp_profiles <- unnest(select(res, season, fold, cp_profiles), cp_profiles)

## Let’s generate an averaged cp profile across folds for each season, propagating uncertainties.
## The difficulty is that x values differ between folds; the solution is to interpolate yhat on a common set of x values across folds.
## Steps as follows for each season and each variable
## 1- compute the mean and spread of cp profiles within each fold
## 2- interpolate yhat value and spread within each fold using a common set of x values
## 3- perform a weighted average of yhat value and spread, using 1/var as weights

# Get names of folds, for later use
folds <- sort(unique(full_vip$fold))

# Apply on each season and variable.
# seq_len() is used rather than 1:nrow() so that an empty `vars_pdp` yields no
# iteration instead of the indices c(1, 0).
mean_pdp <- lapply(seq_len(nrow(vars_pdp)), function(r) {

  # Get variable and season
  var_name <- vars_pdp[r, ]$variable
  season_name <- vars_pdp[r, ]$season

  ## Get corresponding CP profiles, compute mean and spread for each fold (step 1)
  d_pdp <- cp_profiles %>%
    filter(season == season_name & `_vname_` == var_name) %>%
    select(season, fold, `_yhat_`, `_vname_`, `_ids_`, all_of(var_name)) %>%
    arrange(`_ids_`, across(all_of(var_name))) %>%
    # center each cp profile, i.e. within fold, variable and ids
    group_by(season, fold, `_vname_`, `_ids_`) %>%
    mutate(yhat_cent = `_yhat_` - mean(`_yhat_`)) %>%
    ungroup() %>%
    # for each fold and value of the variable of interest: mean of the raw
    # profiles and sd of the centered profiles
    group_by(season, fold, across(all_of(var_name))) %>%
    summarise(
      yhat_loc = mean(`_yhat_`),
      yhat_spr = sd(yhat_cent),
      .groups = "keep"
    ) %>%
    ungroup() %>%
    # positional renaming: the third column is the variable of interest
    setNames(c("season", "fold", "x", "yhat_loc", "yhat_spr"))

  ## Interpolate yhat values and spread on a common x distribution (step 2)
  # Common x grid: percentiles of the x values pooled across folds
  new_x <- quantile(d_pdp$x, probs = seq(0, 1, 0.01), names = FALSE)

  # x differs within each fold, so interpolation is performed fold by fold
  int_pdp <- lapply(seq_along(folds), function(i) {
    # Get data corresponding to this fold
    fold_name <- folds[i]
    this_fold <- d_pdp %>% filter(fold == fold_name)

    # Interpolate yhat_loc and yhat_spr from the fold's own x onto new_x
    tibble(
      x = new_x,
      yhat_loc = castr::interpolate(x = this_fold$x, y = this_fold$yhat_loc, xout = new_x),
      yhat_spr = castr::interpolate(x = this_fold$x, y = this_fold$yhat_spr, xout = new_x)
    ) %>%
      mutate(
        season = season_name,
        fold = fold_name,
        var_name = var_name,
        .before = x
      )
  }) %>%
    bind_rows()

  ## Across folds, compute the weighted mean, using 1/var as weights (step 3)
  # Named differently from the outer `mean_pdp` to avoid shadowing it
  avg_pdp <- int_pdp %>%
    group_by(season, var_name, x) %>%
    summarise(
      yhat_loc = wtd.mean(yhat_loc, weights = 1/(yhat_spr)^2),
      yhat_spr = wtd.mean(yhat_spr, weights = 1/(yhat_spr)^2),
      .groups = "drop"
    ) %>%
    arrange(x)

  # Return the result
  avg_pdp
}) %>%
  bind_rows()

# Attach the averaged profiles to `vars_pdp`, so that the levels of `var_name`
# follow the row order of `vars_pdp` (sorted by decreasing dropout loss)
mean_pdp <- vars_pdp %>%
  rename(var_name = variable) %>%
  left_join(mean_pdp, by = join_by(season, var_name)) %>%
  mutate(var_name = fct_inorder(var_name)) %>%
  select(-dropout_loss)

# Plot it: line = averaged profile, ribbon = profile ± spread
ggplot(mean_pdp, aes(x = x)) +
  geom_path(aes(y = yhat_loc, colour = season)) +
  geom_ribbon(
    aes(ymin = yhat_loc - yhat_spr, ymax = yhat_loc + yhat_spr, fill = season),
    alpha = 0.2
  ) +
  facet_wrap(~var_name, scales = "free_x")

New predictions

Collect predictions

# Unnest new predictions (i.e. projections)
new_preds <- res %>%
  # Only `fold` is needed from `res` here: every other column kept below
  # (season, lon, lat and the doc columns) comes from the nested `new_preds`.
  select(fold, new_preds) %>%
  unnest(new_preds) %>%
  # Apply exp to predictions as we predicted log(doc)
  mutate(pred_doc = exp(pred_doc_log), .after = pred_doc_log) %>%
  select(season, fold, lon, lat, contains("doc")) %>%
  # season as character, matching the type of `rsquares$season` for the join
  mutate(season = as.character(season))

# Join projections with R² value for each fold and each season.
new_preds_strat <- new_preds %>%
  left_join(
    rsquares %>% select(season, fold, rsq = .estimate),
    by = join_by(season, fold)
  )

## Average by pixel and season across folds, weighting each fold by its R²
# NOTE(review): if wtd.var is Hmisc::wtd.var, non-frequency weights such as R²
# may call for normwt = TRUE; to be confirmed.
strat_proj <- new_preds_strat %>%
  group_by(lon, lat, season) %>%
  summarise(
    doc_avg = wtd.mean(pred_doc, weights = rsq, na.rm = TRUE),
    doc_sd = sqrt(wtd.var(pred_doc, weights = rsq, na.rm = TRUE)),
    .groups = "drop"
  )

# Colour bar limits shared by all seasons: c(min, max) of each quantity
doc_avg_lims <- range(strat_proj$doc_avg)
doc_sd_lims <- range(strat_proj$doc_sd)

Maps

Stratified CV

# Map of the fold-averaged DOC projection, one panel per season
ggplot(strat_proj, aes(x = lon, y = lat)) +
  geom_polygon(data = world, aes(group = group), fill = "grey") +
  geom_raster(aes(fill = doc_avg)) +
  ggplot2::scale_fill_viridis_c(option = "F", limits = doc_avg_lims, trans = "log1p") +
  labs(title = "DOC avg from stratified CV") +
  coord_quickmap(expand = 0) +
  facet_wrap(~season)

# Map of the between-fold standard deviation, one panel per season
ggplot(strat_proj, aes(x = lon, y = lat)) +
  geom_polygon(data = world, aes(group = group), fill = "grey") +
  geom_raster(aes(fill = doc_sd)) +
  ggplot2::scale_fill_viridis_c(option = "E", limits = doc_sd_lims, trans = "log1p") +
  labs(title = "DOC sd from stratified CV") +
  coord_quickmap(expand = 0) +
  facet_wrap(~season)

Seasonal cycle

# Per pixel: range (max - min) and variance of the seasonal DOC averages
seas_amp <- summarise(
  group_by(strat_proj, lon, lat),
  seas_amp = max(doc_avg, na.rm = TRUE) - min(doc_avg, na.rm = TRUE),
  seas_var = var(doc_avg, na.rm = TRUE),
  .groups = "drop"
)

# Map of the seasonal amplitude
ggplot(seas_amp, aes(x = lon, y = lat)) +
  geom_polygon(data = world, aes(group = group), fill = "grey") +
  geom_raster(aes(fill = seas_amp)) +
  scale_fill_viridis_c() +
  labs(title = "Seasonal amplitude") +
  coord_quickmap(expand = 0)

# Map of the seasonal variance (log1p colour scale, missing values not drawn)
ggplot(seas_amp, aes(x = lon, y = lat)) +
  geom_polygon(data = world, aes(group = group), fill = "grey") +
  geom_raster(aes(fill = seas_var)) +
  scale_fill_viridis_c(trans = "log1p", na.value = NA) +
  labs(title = "Seasonal variance") +
  coord_quickmap(expand = 0)